from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib import patches
from torchvision import io
import torch
import numpy as np
from CONFIG import config
from datalib import build_data_loader, load_data
from utils.utils import count_model_params, load_model
from torchinfo import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = config['data']['dataset_path']
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload
- Data VisualizationΒΆ
Here, we investigate the transforms on train dataset. We compare the data in a sequence before and after the transformations being applied.
The transforms are applied equally to all the samples in a sequence, preserving the temporal consistency of the video sequence.
Each new sequence gets fresh random decisions for augmentation. This is done in MoviC.py file line 49 (self.transforms.reset_sequence(sequence_idx=idx), num_epochs=self.num_epochs) resets the flags for horizontal and vertical flips.
For the sake of having the same transformation on all data in a sequence, we needed to have different seeds per sequence. But at the same time we need a base/stable seed during training the model and initializing the tensors. To solve this problem, we considered the seed for training as
base_seed, which is increased by idx — the index of the sequence in the dataset — and also a random value in range(epoch). This ensures consistency of the data augmentation within each sequence, diversity across sequences/epochs, and a defined base seed for tensor initialization during training. (This is done in the _make_sequence_decisions() function, line 41 in utils/transforms.py.) Although this approach provides deterministic results for augmentation per sequence, it is non-deterministic per run! For each sequence, the chosen augmentation can be either vertical, horizontal, both, or neither! This is also done in the _make_sequence_decisions() function in MoviC.py. We consider independent probabilities for each augmentation:
- should_hflip = random.random() < 0.3: 30% chance of horizontal flip.
- should_vflip = random.random() < 0.7: 70% chance of vertical flip.
Since these are independent, the possible outcomes per sequence are:
Neither: (1-0.3) * (1-0.7) = 70% * 30% = 21% probability.
Horizontal only: 30% * 30% = 9% probability.
Vertical only: 70% * 70% = 49% probability.
Both: 30% * 70% = 21% probability.
# Apply transforms to ENTIRE sequence using transform pipeline
# Load the train split twice: once raw and once with augmentations, so the
# same sequence index can be compared before/after the transforms.
train_dataset = load_data(path, split='train', use_transforms=False, Visualize=True)
train_dataset_transformed = load_data(path, split='train', use_transforms=True, Visualize=True)
[INFO] - TRAIN Data Loaded: Coordinates: 9737, Masks: 9737, RGB videos: 9737, Flows: 9737 [INFO] - TRAIN Data Loaded: Coordinates: 9737, Masks: 9737, RGB videos: 9737, Flows: 9737
First run: only vertical flip
from utils.visualization import plot_transform_comparison
# Pick a random sequence and show raw vs. augmented modalities side by side.
idx = np.random.randint(0, len(train_dataset))
rgbs_orig, masks_orig, flows_orig, coords_orig = train_dataset[idx]
rgbs_trans, masks_trans, flows_trans, coords_trans = train_dataset_transformed[idx]
plot_transform_comparison(
    rgbs_orig, masks_orig, flows_orig, coords_orig,
    rgbs_trans, masks_trans, flows_trans, coords_trans,
    n_rows=6, sequence_idx=idx
)
Second run: Both horizontal and vertical flips
from utils.visualization import plot_transform_comparison
# Second draw: a different random sequence gets its own fresh augmentation
# decisions, so the flips shown here can differ from the previous run.
idx = np.random.randint(0, len(train_dataset))
rgbs_orig, masks_orig, flows_orig, coords_orig = train_dataset[idx]
rgbs_trans, masks_trans, flows_trans, coords_trans = train_dataset_transformed[idx]
plot_transform_comparison(
    rgbs_orig, masks_orig, flows_orig, coords_orig,
    rgbs_trans, masks_trans, flows_trans, coords_trans,
    n_rows=6, sequence_idx=idx
)
- Dataloaders and modality shapes¶
# Validation split and its dataloader (the train loader is not needed here).
val_dataset = load_data(path, split='validation', use_transforms=True)
# train_loader= build_data_loader(train_dataset, split='train')
val_loader = build_data_loader(val_dataset, split='validation')
[INFO] - VALIDATION Data Loaded: Coordinates: 250, Masks: 250, RGB videos: 250, Flows: 250
## Verifying the dataloader
# Pull one batch and move every modality to the compute device.
# masks and coords are dicts mapping modality names to tensors (see the loop below).
rgbs, masks, flows, coords = next(iter(val_loader))
# Send all tensors to device
rgbs = rgbs.to(device)
flows = flows.to(device)
# Move all mask tensors to device
for k in masks:
    masks[k] = masks[k].to(device)
# Move all coords tensors to device
for k in coords:
    coords[k] = coords[k].to(device)
print(f"RGBs shape: {rgbs.shape}\nFlows shape: {flows.shape}\nMasks shape: {masks['masks'].shape} \nCoords com shape: {coords['com'].shape}\nCoords bbxs shape: {coords['bbox'].shape}")
RGBs shape: torch.Size([32, 24, 3, 128, 128]) Flows shape: torch.Size([32, 24, 3, 128, 128]) Masks shape: torch.Size([32, 24, 128, 128]) Coords com shape: torch.Size([32, 24, 11, 2]) Coords bbxs shape: torch.Size([32, 24, 11, 4])
- Utility visualizationΒΆ
Here, we will visualize some of the helper functions used during training
- During object-centric scene representation learning, each token is an object image. Therefore, we need to extract object images from frames. We can achieve this goal using either bbox or mask labels. First we will explore extracting object frames from masks:
'''
Used during training the object-centric model. when masks labels are used to extract object frames from one image.
The object frames are then used to guide the prediction of the next frame.
'''
def extract_object_specific_frames_from_masks(images, masks, num_objects):
    """
    Split each frame into per-object frames using segmentation masks.

    Each output frame keeps only the pixels belonging to one object id
    (everything else is zero), so every object becomes its own token image.

    Args:
        images: Tensor of shape [B, T, C, H, W].
        masks: int Tensor of shape [B, T, H, W] with values from 0 to num_objects-1.
        num_objects: int, number of unique objects (including background if needed).

    Returns:
        object_frames: Tensor of shape [B, T, num_objects, C, H, W].
    """
    device = images.device
    # One comparison per object id, fully vectorized instead of a Python loop
    # with a materialized `expand`: obj_masks is [B, T, num_objects, H, W].
    obj_ids = torch.arange(num_objects, device=device).view(1, 1, num_objects, 1, 1)
    obj_masks = masks.unsqueeze(2) == obj_ids
    # Broadcast over the channel axis:
    # [B, T, 1, C, H, W] * [B, T, num_objects, 1, H, W] -> [B, T, num_objects, C, H, W].
    return images.unsqueeze(2) * obj_masks.unsqueeze(3).to(images.dtype)
# 11 = 1 background slot + up to 10 objects per MOVi-C scene.
num_objects = 11
object_frames = extract_object_specific_frames_from_masks(rgbs, masks['masks'], num_objects)
object_frames.shape
torch.Size([32, 24, 11, 3, 128, 128])
- As we can see, we extracted 11 (one background + 10 objects in the MOVi-C dataset) different "object_frames" from one image. Now we visualize them:
from utils.visualization import plot_object_frames
# Choose the first batch and time step
batch_idx = 3
seq_idx = 11
# object_frames shape: [B, T, num_objects, C, H, W]
# We'll print all object frames for this batch and time step
# NOTE(review): num_objects is recomputed here but not passed to the plot
# call below — appears unused; confirm before relying on it.
num_objects = len(object_frames[batch_idx][seq_idx])
plot_object_frames(rgbs, object_frames, batch_idx, seq_idx)
- Now we investigate extracting object frames from bounding boxes:
def extract_object_specific_frames_from_bboxes(images, bboxes):
    """
    Split each frame into per-object frames using bounding boxes.

    Each output frame keeps only the pixels inside one object's box
    (everything else is zero).

    Args:
        images: Tensor of shape [B, T, C, H, W].
        bboxes: Tensor of shape [B, T, num_objects, 4] (x1, y1, x2, y2) in pixel coordinates.

    Returns:
        object_frames: Tensor of shape [B, T, num_objects, C, H, W].
    """
    B, T, C, H, W = images.shape
    device = images.device
    num_objects = bboxes.shape[2]
    # Prepare output tensor
    object_frames = torch.zeros(B, T, num_objects, C, H, W, device=device, dtype=images.dtype)
    for obj_id in range(num_objects):
        for b in range(B):
            for t in range(T):
                x1, y1, x2, y2 = bboxes[b, t, obj_id]
                # Clamp coordinates to image bounds and convert to int.
                # Slicing is end-exclusive, so the END coordinates may
                # legitimately equal W/H (full-width/height boxes); clamping
                # them to W-1/H-1 would drop the last row/column of pixels.
                x1 = int(torch.clamp(x1, 0, W - 1).item())
                y1 = int(torch.clamp(y1, 0, H - 1).item())
                x2 = int(torch.clamp(x2, 0, W).item())
                y2 = int(torch.clamp(y2, 0, H).item())
                # Ensure valid bbox
                if x2 > x1 and y2 > y1:
                    # Copy the region from the image to the corresponding location in object_frames
                    object_frames[b, t, obj_id, :, y1:y2, x1:x2] = images[b, t, :, y1:y2, x1:x2]
                # else: leave as zeros (background)
    return object_frames
object_frames = extract_object_specific_frames_from_bboxes(rgbs, coords['bbox'])
object_frames.shape
torch.Size([32, 24, 11, 3, 128, 128])
plot_object_frames(rgbs, object_frames, batch_idx, seq_idx)
PatchifierΒΆ
Another utility function to patchify input images. Only used in the Holistic scene representation training where each patch is considered as a token.
from model.model_utils import Patchifier
# Show one frame, then display each of its non-overlapping square patches.
BATCH_IDX = 4
seq_len = 5
img = rgbs[BATCH_IDX, seq_len]
plt.figure(figsize=(3, 3))
plt.imshow(img.permute(1, 2, 0).cpu().numpy())
plt.axis("off")
patch_size = config['data']['patch_size']  # num of H and W pixels of each patch
patchifier = Patchifier(patch_size)
patch_data = patchifier(rgbs)
print(f"Patchified Shape: {patch_data.shape}")  # (B, seq_len, num_patch_H * num_patch_W, 3 * 32 * 32)
num_patches = patch_data.shape[2]  # num_patches = num_patch_H * num_patch_W
print(f"Number of patches: {num_patches}")
print(f"Patch size: {patch_size}")
fig, ax = plt.subplots(1, num_patches)
fig.set_size_inches(3 * num_patches, 3)
for i in range(num_patches):
    # Each flattened patch vector is reshaped back to (C, patch_size, patch_size) for display.
    cur_patch = patch_data[BATCH_IDX, seq_len, i].reshape(3, patch_size, patch_size)
    ax[i].imshow(cur_patch.permute(1, 2, 0).cpu().numpy())
    ax[i].set_title(f"Patch {i+1}")
    ax[i].axis("off")
plt.show()
Patchified Shape: torch.Size([32, 24, 64, 768]) Number of patches: 64 Patch size: 16
ExperimentsΒΆ
All the experiments were run on different servers and machines at Uni Bonn. In particular, we used the cuda1, cuda2, cuda3, cuda4, and cuda6 machines, each with one NVIDIA GeForce RTX 4090/3090 GPU with 24 gigabytes of memory.
1. Holistic scene representationΒΆ
- Holistic Transformer-AutoEncoder ModuleΒΆ
from model.holistic_encoder import HolisticEncoder
from model.holistic_decoder import HolisticDecoder
from model.ocvp import TransformerAutoEncoder
- We used
Full-Transformer-based autoencoder for the holistic scene scenario. Each image was patchified into tokens, positional embeddings were added, the tokens went through transformer blocks, and the result was normalized to get the desired embeddings. During the training process, we experimented with different transformer architectures, along with input image sizes: (64*64) — right table, with moderate parameters, called Base — and (128*128) — left table, with larger transformer parameters, called XL. The configs were as follows:
Large model configs and learning process
| Parameter | Value |
|---|---|
| model_name | 02_Holistic_AE_XL |
| batch_size | 32 |
| patch_size | 16 |
| num_workers | 8 |
| num_epochs | 100 |
| warmup_epochs | 5 |
| early_stopping_patience | 10 |
| lr | 0.001 |
| encoder_embed_dim | 512 |
| decoder_embed_dim | 384 |
| max_len | 64 |
| in_out_channels | 3 |
| attn_dim | 128 |
| num_heads | 8 |
| mlp_size | 1024 |
| encoder_depth | 12 |
| decoder_depth | 8 |
| predictor_depth | 8 |
| predictor_embed_dim | 256 |
| residual | true |
Based model configs and learning process
| Parameter | Value |
|---|---|
| model_name | 01_Holistic_AE_Base |
| batch_size | 32 |
| patch_size | 16 |
| num_workers | 8 |
| num_epochs | 100 |
| warmup_epochs | 5 |
| early_stopping_patience | 10 |
| lr | 0.0002 |
| encoder_embed_dim | 128 |
| decoder_embed_dim | 64 |
| max_len | 64 |
| in_out_channels | 3 |
| attn_dim | 64 |
| num_heads | 8 |
| mlp_size | 512 |
| encoder_depth | 6 |
| decoder_depth | 3 |
| predictor_depth | 4 |
| predictor_embed_dim | 128 |
| residual | true |
Loss curves for different experimments:ΒΆ
Due to the experiments being run in different machines, the tensorboard logs from the machines related to that experiment will be provided here.
Note:
our base model could also learn properly with far fewer parameters, but in the end the quality of the reconstructions with the XL model was better and sharper. So we proceeded with this model for predictor training
# Assemble the holistic autoencoder (encoder + decoder) and summarize it.
holistic_encoder = HolisticEncoder()
holistic_decoder = HolisticDecoder()
model = TransformerAutoEncoder(holistic_encoder, holistic_decoder).to(device)
summary(model, input_size= rgbs.shape)
==================================================================================================== Layer (type:depth-idx) Output Shape Param # ==================================================================================================== TransformerAutoEncoder [8, 24, 3, 128, 128] -- ββHolisticEncoder: 1-1 [8, 24, 64, 512] -- β ββSequential: 2-1 [8, 24, 64, 512] -- β β ββLayerNorm: 3-1 [8, 24, 64, 768] 1,536 β β ββLinear: 3-2 [8, 24, 64, 512] 393,728 β ββPositionalEncoding: 2-2 [8, 24, 64, 512] -- β ββSequential: 2-3 [8, 24, 64, 512] -- β β ββTransformerBlock: 3-3 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-4 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-5 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-6 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-7 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-8 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-9 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-10 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-11 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-12 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-13 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-14 [8, 24, 64, 512] 1,314,304 β ββLayerNorm: 2-4 [8, 24, 64, 512] 1,024 ββHolisticDecoder: 1-2 [8, 24, 3, 128, 128] -- β ββLinear: 2-5 [8, 24, 64, 384] 196,992 β ββPositionalEncoding: 2-6 [8, 24, 64, 384] -- β ββSequential: 2-7 -- -- β β ββTransformerBlock: 3-15 [8, 24, 64, 384] 985,984 β β ββTransformerBlock: 3-16 [8, 24, 64, 384] 985,984 β β ββTransformerBlock: 3-17 [8, 24, 64, 384] 985,984 β β ββTransformerBlock: 3-18 [8, 24, 64, 384] 985,984 β β ββTransformerBlock: 3-19 [8, 24, 64, 384] 985,984 β β ββTransformerBlock: 3-20 [8, 24, 64, 384] 985,984 β β ββTransformerBlock: 3-21 [8, 24, 64, 384] 985,984 β β ββTransformerBlock: 3-22 [8, 24, 64, 384] 985,984 β ββLayerNorm: 2-8 [8, 24, 64, 384] 768 β ββLinear: 2-9 [8, 24, 64, 768] 295,680 
==================================================================================================== Total params: 24,549,248 Trainable params: 24,549,248 Non-trainable params: 0 Total mult-adds (Units.MEGABYTES): 413.45 ==================================================================================================== Input size (MB): 37.75 Forward/backward pass size (MB): 6719.28 Params size (MB): 98.20 Estimated Total Size (MB): 6855.22 ====================================================================================================
# Full forward pass through the model
with torch.no_grad():
encoded_features = holistic_encoder(rgbs)
print("Encoded Features shape:", encoded_features.shape)
recons, loss = holistic_decoder(encoded_features)
print("Reconstructed image shape:", recons.shape)
print(f"Reconstructed images match the original images shape: {recons.shape == rgbs.shape}")
Encoded Features shape: torch.Size([8, 24, 16, 512]) Reconstructed image shape: torch.Size([8, 24, 3, 64, 64]) Reconstructed images match the original images shape: True
Now we will evaluate the ability of our pre-trained autoencoders by running a full-forward pass through network using eval images and visualize the results for some sequences.
# Load the best pre-trained XL autoencoder checkpoint for evaluation.
path_AE = 'experiments/02_Holistic_AE_XL/checkpoints/best_02_Holistic_AE_XL.pth'
model = load_model(model, mode='AE_inference', path_AE = path_AE)
from utils.visualization import plot_images_vs_recons
# Generate reconstructions
with torch.no_grad():
    recons, _ = model(rgbs)
# Plot 5 random sequences, each showing 8 original vs 8 reconstructed frames
plot_images_vs_recons(rgbs, recons)
- NOTE :
As we can see, our autoencoder has efficiently learnt the latent space of the input images and can reconstruct/map any embedding in the given latent space to its relevant input image.
- PSNR, SSIM, and LPIPS metrics analysis:ΒΆ
from utils.metrics import evaluate_metrics
evaluate_metrics(recons, rgbs)
PSNR Mean: 26.412456512451172
PSNR Framewise: tensor([26.7287, 26.6270, 27.3801, 26.3104, 26.7908, 27.1606, 26.2990, 26.9788,
26.5026, 26.6112, 26.3883, 26.8132, 26.4445, 25.7483, 26.5985, 25.8677,
25.6586, 25.8612, 26.2204, 26.4984, 25.9288, 26.0611, 26.1334, 26.2871],
device='cuda:0')
SSIM Mean: 0.9157991409301758
SSIM Framewise: tensor([0.9209, 0.9205, 0.9219, 0.9152, 0.9161, 0.9146, 0.9131, 0.9188, 0.9141,
0.9140, 0.9158, 0.9166, 0.9153, 0.9162, 0.9202, 0.9115, 0.9147, 0.9133,
0.9133, 0.9164, 0.9152, 0.9152, 0.9133, 0.9129], device='cuda:0')
LPIPS Mean: 0.03445807844400406
LPIPS Framewise: tensor([0.0360, 0.0340, 0.0341, 0.0379, 0.0334, 0.0351, 0.0360, 0.0341, 0.0329,
0.0335, 0.0350, 0.0349, 0.0343, 0.0336, 0.0350, 0.0368, 0.0354, 0.0363,
0.0318, 0.0322, 0.0329, 0.0339, 0.0340, 0.0338], device='cuda:0')
2- The base modelΒΆ
# Load the Base autoencoder checkpoint.
# NOTE(review): `model` was built with the XL configuration above — presumably
# load_model rebuilds/loads a compatible architecture for this checkpoint; confirm.
path_AE = 'experiments/01_Holistic_AE_Base/checkpoints/best_01_Holistic_AE_Base.pth'
model = load_model(model, mode='AE_inference', path_AE = path_AE)
summary(model, input_size= rgbs.shape)
==================================================================================================== Layer (type:depth-idx) Output Shape Param # ==================================================================================================== TransformerAutoEncoder [8, 24, 3, 128, 128] -- ββHolisticEncoder: 1-1 [8, 24, 64, 128] -- β ββSequential: 2-1 [8, 24, 64, 128] -- β β ββLayerNorm: 3-1 [8, 24, 64, 768] 1,536 β β ββLinear: 3-2 [8, 24, 64, 128] 98,432 β ββPositionalEncoding: 2-2 [8, 24, 64, 128] -- β ββSequential: 2-3 [8, 24, 64, 128] -- β β ββTransformerBlock: 3-3 [8, 24, 64, 128] 164,992 β β ββTransformerBlock: 3-4 [8, 24, 64, 128] 164,992 β β ββTransformerBlock: 3-5 [8, 24, 64, 128] 164,992 β β ββTransformerBlock: 3-6 [8, 24, 64, 128] 164,992 β β ββTransformerBlock: 3-7 [8, 24, 64, 128] 164,992 β β ββTransformerBlock: 3-8 [8, 24, 64, 128] 164,992 β ββLayerNorm: 2-4 [8, 24, 64, 128] 256 ββHolisticDecoder: 1-2 [8, 24, 3, 128, 128] -- β ββLinear: 2-5 [8, 24, 64, 64] 8,256 β ββPositionalEncoding: 2-6 [8, 24, 64, 64] -- β ββSequential: 2-7 -- -- β β ββTransformerBlock: 3-9 [8, 24, 64, 64] 82,752 β β ββTransformerBlock: 3-10 [8, 24, 64, 64] 82,752 β β ββTransformerBlock: 3-11 [8, 24, 64, 64] 82,752 β ββLayerNorm: 2-8 [8, 24, 64, 64] 128 β ββLinear: 2-9 [8, 24, 64, 768] 49,920 ==================================================================================================== Total params: 1,396,736 Trainable params: 1,396,736 Non-trainable params: 0 Total mult-adds (Units.MEGABYTES): 22.48 ==================================================================================================== Input size (MB): 37.75 Forward/backward pass size (MB): 1189.09 Params size (MB): 5.59 Estimated Total Size (MB): 1232.42 ====================================================================================================
from utils.visualization import plot_images_vs_recons
# Reconstruct the same batch with the Base model for comparison against XL.
# Generate reconstructions
with torch.no_grad():
    recons, _ = model(rgbs)
# Plot 5 random sequences, each showing 8 original vs 8 reconstructed frames
plot_images_vs_recons(rgbs, recons)
- PSNR, SSIM, and LPIPS metrics analysis:ΒΆ
from utils.metrics import evaluate_metrics
evaluate_metrics(recons, rgbs)
PSNR Mean: 21.13724136352539
PSNR Framewise: tensor([22.7105, 23.3850, 22.2196, 21.7082, 20.2734, 20.9578, 20.5842, 20.9267,
20.5974, 20.4753, 21.1178, 21.7913, 21.4562, 20.7999, 20.8086, 21.4605,
21.2202, 20.4990, 20.5013, 20.6580, 20.8458, 20.3676, 20.8827, 21.0468],
device='cuda:0')
SSIM Mean: 0.7252386808395386
SSIM Framewise: tensor([0.7582, 0.7554, 0.7433, 0.7292, 0.7126, 0.7124, 0.7161, 0.7119, 0.7105,
0.7053, 0.7140, 0.7222, 0.7255, 0.7237, 0.7231, 0.7261, 0.7281, 0.7220,
0.7208, 0.7227, 0.7281, 0.7289, 0.7306, 0.7350], device='cuda:0')
LPIPS Mean: 0.3330411911010742
LPIPS Framewise: tensor([0.3265, 0.3193, 0.3220, 0.3287, 0.3255, 0.3414, 0.3321, 0.3354, 0.3517,
0.3403, 0.3370, 0.3318, 0.3359, 0.3366, 0.3474, 0.3420, 0.3471, 0.3330,
0.3285, 0.3270, 0.3214, 0.3293, 0.3256, 0.3276], device='cuda:0')
- Performance analysis and comparisonΒΆ
PSNR (Peak Signal-to-Noise Ratio):
Measures how close your prediction is to the ground truth at the pixel level.
Higher is better (means less error/noise).SSIM (Structural Similarity Index):
Measures how well your prediction preserves the structure and details of the ground truth image.
Higher is better (closer to 1 means almost identical structure).LPIPS (Learned Perceptual Image Patch Similarity):
Measures perceptual similarity using a neural networkβcloser to how humans judge images.
Lower is better (0 means visually identical to a human).
| Metric | Description | Higher/Lower is better | XL Model | Base Model | Which is better |
|---|---|---|---|---|---|
| PSNR | Pixel-wise fidelity | Higher | 26.41 | 21.14 | XL |
| SSIM | Structural similarity | Higher | 0.916 | 0.725 | XL |
| LPIPS | Perceptual similarity | Lower | 0.0345 | 0.3330 | XL |
As we can see, the XL model is significantly superior to the Base model in pixel fidelity, structural integrity, and perceptual similarity, producing much more realistic and accurate predictions.
- Holositic Transformer-Predictor ModuleΒΆ
We chose the XL model due to its better performance and trained the predictor in the second phase of training.
from model.holistic_predictor import HolisticTransformerPredictor
from model.predictor_wrapper import PredictorWrapper
from model.ocvp import TransformerPredictor
# Wrap the predictor and assemble the encoder -> predictor -> decoder
# pipeline in inference mode, then summarize its parameter counts.
holistic_predictor = HolisticTransformerPredictor()
holistic_predictor= PredictorWrapper(holistic_predictor)
model = TransformerPredictor(holistic_encoder, holistic_decoder, holistic_predictor, mode='inference').to(device)
summary(model, input_size= rgbs.shape)
========================================================================================================= Layer (type:depth-idx) Output Shape Param # ========================================================================================================= TransformerPredictor [8, 5, 64, 512] 8,381,312 ββHolisticEncoder: 1-1 [8, 24, 64, 512] -- β ββSequential: 2-1 [8, 24, 64, 512] -- β β ββLayerNorm: 3-1 [8, 24, 64, 768] 1,536 β β ββLinear: 3-2 [8, 24, 64, 512] 393,728 β ββPositionalEncoding: 2-2 [8, 24, 64, 512] -- β ββSequential: 2-3 [8, 24, 64, 512] -- β β ββTransformerBlock: 3-3 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-4 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-5 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-6 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-7 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-8 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-9 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-10 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-11 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-12 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-13 [8, 24, 64, 512] 1,314,304 β β ββTransformerBlock: 3-14 [8, 24, 64, 512] 1,314,304 β ββLayerNorm: 2-4 [8, 24, 64, 512] 1,024 ββPredictorWrapper: 1-2 [8, 5, 64, 512] -- β ββHolisticTransformerPredictor: 2-5 [8, 5, 64, 512] -- β β ββLinear: 3-15 [8, 5, 64, 256] 131,328 β β ββPositionalEncoding: 3-16 [8, 5, 64, 256] -- β β ββSequential: 3-17 [8, 5, 64, 256] 5,261,312 β β ββLayerNorm: 3-18 [8, 5, 64, 256] 512 β β ββLinear: 3-19 [8, 5, 64, 512] 131,584 β ββHolisticTransformerPredictor: 2-6 [8, 5, 64, 512] (recursive) β β ββLinear: 3-20 [8, 5, 64, 256] (recursive) β β ββPositionalEncoding: 3-21 [8, 5, 64, 256] -- β β ββSequential: 3-22 [8, 5, 64, 256] (recursive) β β ββLayerNorm: 3-23 [8, 5, 64, 256] (recursive) β β ββLinear: 3-24 [8, 5, 64, 512] (recursive) β ββHolisticTransformerPredictor: 2-7 [8, 5, 64, 512] (recursive) β β ββLinear: 3-25 [8, 5, 
64, 256] (recursive) β β ββPositionalEncoding: 3-26 [8, 5, 64, 256] -- β β ββSequential: 3-27 [8, 5, 64, 256] (recursive) β β ββLayerNorm: 3-28 [8, 5, 64, 256] (recursive) β β ββLinear: 3-29 [8, 5, 64, 512] (recursive) β ββHolisticTransformerPredictor: 2-8 [8, 5, 64, 512] (recursive) β β ββLinear: 3-30 [8, 5, 64, 256] (recursive) β β ββPositionalEncoding: 3-31 [8, 5, 64, 256] -- β β ββSequential: 3-32 [8, 5, 64, 256] (recursive) β β ββLayerNorm: 3-33 [8, 5, 64, 256] (recursive) β β ββLinear: 3-34 [8, 5, 64, 512] (recursive) β ββHolisticTransformerPredictor: 2-9 [8, 5, 64, 512] (recursive) β β ββLinear: 3-35 [8, 5, 64, 256] (recursive) β β ββPositionalEncoding: 3-36 [8, 5, 64, 256] -- β β ββSequential: 3-37 [8, 5, 64, 256] (recursive) β β ββLayerNorm: 3-38 [8, 5, 64, 256] (recursive) β β ββLinear: 3-39 [8, 5, 64, 512] (recursive) ========================================================================================================= Total params: 30,073,984 Trainable params: 30,073,984 Non-trainable params: 0 Total mult-adds (Units.MEGABYTES): 536.98 ========================================================================================================= Input size (MB): 37.75 Forward/backward pass size (MB): 6350.18 Params size (MB): 86.77 Estimated Total Size (MB): 6474.70 =========================================================================================================
Loss curves for different experimments:ΒΆ
- We tried many different combinations of experiments for the predictor. The model architecture and embedding sizes of the whole network needed to be the same for predictor training. So we were essentially forced to stick to the initial AE model configs during predictor training. We chose the "best_03_Holistic_Predictor_XL" model, which reached the best results, for the visualization here
Loading model pre-trained Holistic-Predictor checkpointsΒΆ
path_AE = 'experiments/02_Holistic_AE_XL/checkpoints/best_02_Holistic_AE_XL.pth'
path_predictor = 'experiments/03_Holistic_Predictor_XL/checkpoints/best_03_Holistic_Predictor_XL.pth'
model = load_model(model, mode='inference', path_AE = path_AE, path_predictor=path_predictor)
from utils.visualization import plot_predictor_images
# Visualize predictions for 3 fixed sequences from the batch.
idx = [0, 3, 7]
with torch.no_grad():
    for i in idx:
        # Predict future-frame embeddings for one sequence (batch dim kept).
        preds, loss, input_range, target_range = model(rgbs[i].unsqueeze(0))
        # Decode the predicted embeddings back to images.
        recons, _ = model.decoder(preds)
        # Slice the frames that served as input and as prediction targets.
        input_images = rgbs[i, input_range[0]:input_range[1]].unsqueeze(0)
        target_images = rgbs[i, target_range[0]:target_range[1]].unsqueeze(0)
        plot_predictor_images(input_images, target_images, recons)
NoteΒΆ
As we can see, our predictor model could do better in predicting the future frames and providing more meaningful embeddings for the decoder module. As our pre-trained AE module was robust, we believe that this is fully related to the predictor module. The results are the best we could get with the current network architecture. Maybe if we had considered deeper transformer blocks for the predictor, on a smaller embedding size, our predictor could have produced more meaningful predictions. However, given no information about objects in the holistic scenario, we could not expect the predictor to make perfect predictions (we can expect this from our object-centric predictor), but at least we could expect better reconstructions and more meaningful embeddings. However, this would have needed end-to-end training with new configs.
2. Object-Centric Scene RepresentationΒΆ
- Object-Centric Transformer-AutoEncoder Module
In this scenario, instead of image patches, each extracted object_frame (as explained above) would be a token.
We ran many experiments for this case!! The fist set of experiments were with linear layers as the encoder input (to get the input embeddings for the transformer module) and also in the decoder output (to reconstruct the input image from decoder embedding). The linear layers impose a large number of parameters to the network! But they are faster to train! Following are our results for OC-AE using this setup
NOTEΒΆ
As we can see, the reconstructions were not the best and they remained blurry. We believe this was due to the large compression rate in the encoder layer (say: 3*64*64 ---> encoder_embedding_dim) and also the decompression in the decoder (decoder_embed_dim ---> 3*64*64), which introduced a lot of parameters (nearly 800M parameters needed to do the job!) to the model and made convergence hard!
For this reason, it made a lot of sense to replace those layers with a respective CNN-based network. Therefore, we did further experiments with the hybrid Transformer-CNN architecture. The results are as follows:
NoteΒΆ
Although we tried many different CNN architectures and experiments — from various simple to complex CNN encoder-decoders to a mixed combination of a linear encoder and CNN decoder to get better reconstructions — we were unsuccessful in getting the best reconstructions using CNN networks! We mostly used mask labels to extract the object frames. We also tried more expressive loss functions (a mixture of MSE+L1 loss). We found it very hard to reach a well-working combination of CNN and transformer networks for the OC-AE task. For this reason, we proceeded with our best OC-AE model so far. This model had the following configs:
| Parameter | Value |
|---|---|
| model_name | 01_OC_AE_XL_64_Full_CNN |
| batch_size | 32 |
| patch_size | 16 |
| max_objects | 11 |
| image_height | 64 |
| image_width | 64 |
| num_epochs | 100 |
| warmup_epochs | 15 |
| early_stopping_patience | 15 |
| lr | 0.001 |
| encoder_embed_dim | 256 |
| decoder_embed_dim | 192 |
| max_len | 64 |
| in_out_channels | 3 |
| attn_dim | 128 |
| num_heads | 8 |
| mlp_size | 1024 |
| encoder_depth | 12 |
| decoder_depth | 8 |
| predictor_depth | 8 |
| predictor_embed_dim | 192 |
| num_preds | 5 |
| predictor_window_size | 5 |
| residual | true |
| use_masks | true |
| use_bboxes | false |
from model.oc_encoder import ObjectCentricEncoder
from model.oc_decoder import ObjectCentricDecoder
from model.ocvp import TransformerAutoEncoder
# Object-centric autoencoder: tokens are per-object frames, not image patches.
oc_encoder = ObjectCentricEncoder()
oc_decoder = ObjectCentricDecoder()
model = TransformerAutoEncoder(oc_encoder, oc_decoder).to(device)
# summary(model, input_size= rgbs.shape)
with torch.no_grad():
    # The encoder consumes masks/coords to build per-object tokens; the
    # decoder reconstructs the frame and returns a loss against rgbs.
    encoded_features = oc_encoder(rgbs, masks, coords)
    print("Encoded features shape:", encoded_features.shape)
    recons, loss = oc_decoder(encoded_features, rgbs)
    print("Reconstructed output shape:", recons.shape)
Encoded features shape: torch.Size([32, 24, 11, 256]) Reconstructed output shape: torch.Size([32, 24, 3, 64, 64])
We can see that our CNN encoder properly produces 11 embeddings for all the objects in the scene, and the decoder reconstructs the input to the same size
- Loading Onject-Centric-AE model pre-trained checkpointsΒΆ
from utils.visualization import plot_images_vs_recons
# Load the best object-centric AE checkpoint and reconstruct a batch.
path_AE = 'experiments/01_OC_AE_XL_64_Full_CNN/checkpoints/best_01_OC_AE_XL_64_Full_CNN.pth'
model = load_model(model, mode='AE_inference', path_AE = path_AE)
# Generate reconstructions
with torch.no_grad():
    recons, _ = model(rgbs,masks)
# Plot 5 random sequences, each showing 8 original vs 8 reconstructed frames
plot_images_vs_recons(rgbs, recons)
- PSNR, SSIM, and LPIPS metrics analysis:ΒΆ
from utils.metrics import evaluate_metrics
evaluate_metrics(recons, rgbs)
PSNR Mean: 24.983428955078125
PSNR Framewise: tensor([25.6122, 25.3229, 24.9899, 24.8046, 24.5098, 24.4270, 24.4462, 24.5905,
24.6441, 24.6562, 24.6746, 24.7518, 24.8696, 24.9073, 24.9973, 25.0703,
25.1042, 25.1092, 25.1894, 25.2762, 25.3370, 25.4012, 25.4492, 25.4616],
device='cuda:0')
SSIM Mean: 0.7169415354728699
SSIM Framewise: tensor([0.7294, 0.7254, 0.7211, 0.7150, 0.7064, 0.7033, 0.7019, 0.7087, 0.7133,
0.7163, 0.7199, 0.7191, 0.7202, 0.7197, 0.7205, 0.7200, 0.7193, 0.7177,
0.7189, 0.7175, 0.7160, 0.7175, 0.7201, 0.7194], device='cuda:0')
LPIPS Mean: 0.24213504791259766
LPIPS Framewise: tensor([0.2474, 0.2425, 0.2369, 0.2456, 0.2531, 0.2534, 0.2601, 0.2444, 0.2392,
0.2403, 0.2391, 0.2362, 0.2336, 0.2404, 0.2380, 0.2475, 0.2408, 0.2410,
0.2409, 0.2388, 0.2419, 0.2406, 0.2374, 0.2321], device='cuda:0')
NOTEΒΆ
Not too bad!! Not the best! Could be better. We proceeded to the predictor training with this model
- Object-Centric Transformer-Predictor Module
from model.oc_predictor import ObjectCentricTransformerPredictor
from model.predictor_wrapper import PredictorWrapper
from model.ocvp import TransformerPredictor
# Assemble the object-centric predictor pipeline and load both checkpoints
# (pre-trained AE weights plus the trained predictor weights).
oc_predictor = ObjectCentricTransformerPredictor()
predictor = PredictorWrapper(oc_predictor)
model = TransformerPredictor(oc_encoder, oc_decoder, predictor, mode='inference').to(device)
from utils.visualization import plot_predictor_images
path_AE = 'experiments/01_OC_AE_XL_64_Full_CNN/checkpoints/best_01_OC_AE_XL_64_Full_CNN.pth'
path_predictor = 'experiments/01_OC_Predictor_XL/checkpoints/best_01_OC_Predictor_XL.pth'
model = load_model(model, mode='inference', path_AE = path_AE, path_predictor=path_predictor)
# Visualize predictions for 3 fixed sequences from the batch.
idx = [0, 3, 7]
with torch.no_grad():
    for i in idx:
        # Slice a single sequence while keeping the batch dimension.
        sequence_rgbs = rgbs[i:i+1]
        sequence_masks = {k: v[i:i+1] for k, v in masks.items()}  # masks is a dict of tensors
        # Predict future-frame embeddings, then decode them to images.
        preds, loss, input_range, target_range = model(sequence_rgbs, sequence_masks)
        recons, _ = model.decoder(preds)
        # Extract the frames that served as input and as prediction targets.
        start_input, end_input = input_range
        start_target, end_target = target_range
        input_images = sequence_rgbs[:, start_input:end_input]
        target_images = sequence_rgbs[:, start_target:end_target]
        # Visualize
        plot_predictor_images(input_images, target_images, recons)
NOTEΒΆ
We can see that our OC-Predictor could actually produce much better future scene predictions, compared to the holistic predictor module. This is due to the fact that in the OC scene scenario, the model benefits from the object frames extracted from masks/bboxes, and through multi-head attention we can actually capture the temporal relations between different objects in the scene.